/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "execute_graph.h"
#include "execute_graph_builder.h"
#include "execute_node.h"
#include "op_desc.h"
#include "ge/infer_shape_functions.h"
#include "ge/infer_shape_mat_mul.h"

#include <gtest/gtest.h>
#include "common_reg.h"
namespace ge {
namespace {
/**
 * Reports whether `graph` holds an edge whose source is (`src_node`, `src_index`)
 * and whose destination is (`dst_node`, `dst_index`). Edges are inspected in
 * index order and all four fields have to match.
 */
bool Connected(ExecuteGraph *graph, ExecuteNode *src_node, int src_index, ExecuteNode *dst_node, int dst_index) {
  size_t edge_pos = 0;
  while (edge_pos < graph->GetEdgeCount()) {
    auto candidate = graph->GetEdgeByIndex(edge_pos);
    const bool same_source = (candidate->src_node == src_node) && (candidate->src_index == src_index);
    const bool same_target = (candidate->dst_node == dst_node) && (candidate->dst_index == dst_index);
    if (same_source && same_target) {
      return true;
    }
    ++edge_pos;
  }
  return false;
}
/**
 * Name-based convenience overload: looks both nodes up with FindNodeByName and
 * forwards to the pointer-based Connected().
 */
bool Connected(ExecuteGraph *graph, const char *src_node, int src_index, const char *dst_node, int dst_index) {
  auto source = graph->FindNodeByName(src_node);
  auto target = graph->FindNodeByName(dst_node);
  return Connected(graph, source, src_index, target, dst_index);
}
/**
 * Checks that output `src_index` of `src_node` and input `dst_index` of
 * `dst_node` resolve to the very same, non-null tensor desc object.
 *
 * Returns false instead of dereferencing null when either node or its op desc
 * is missing (for example when the caller's lookup by name found nothing), so a
 * broken graph shows up as a failed expectation rather than a crash.
 */
bool SameTensor(ExecuteNode *src_node, int src_index, ExecuteNode *dst_node, int dst_index) {
  if (src_node == nullptr || dst_node == nullptr) {
    return false;
  }
  auto src_op_desc = src_node->GetOpDesc();
  auto dst_op_desc = dst_node->GetOpDesc();
  if (src_op_desc == nullptr || dst_op_desc == nullptr) {
    return false;
  }
  // Look the output desc up once; it must exist and be the identical object used as the input desc.
  auto out_tensor = src_op_desc->MutableOutputDesc(src_index);
  return out_tensor != nullptr && out_tensor == dst_op_desc->MutableInputDesc(dst_index);
}
/**
 * Name-based convenience overload: looks both nodes up with FindNodeByName and
 * forwards to the pointer-based SameTensor().
 */
bool SameTensor(ExecuteGraph *graph, const char *src_node, int src_index, const char *dst_node, int dst_index) {
  auto producer = graph->FindNodeByName(src_node);
  auto consumer = graph->FindNodeByName(dst_node);
  return SameTensor(producer, src_index, consumer, dst_index);
}

/**
 * Collects the op-desc name of every node, walking the graph by node index
 * (presumably the topological order the graph stores them in).
 */
std::vector<std::string> GetNodeNamesByTopo(ExecuteGraph *graph) {
  std::vector<std::string> ordered_names;
  for (size_t pos = 0; pos < graph->GetNodeCount(); ++pos) {
    auto node = graph->GetNodeByIndex(pos);
    ordered_names.emplace_back(node->GetOpDesc()->GetName());
  }
  return ordered_names;
}

/**
 * Adds a node to the builder and configures its name, op def and the names of
 * its inputs and outputs (assigned to indices 0, 1, ... in list order).
 *
 * @return the offset of the newly added node as returned by gb.AddNode().
 */
PoolOffset AddNode(ExecuteGraphBuilder &gb, const char *name, PoolOffset op_def,
                   std::initializer_list<std::string> inputs, std::initializer_list<std::string> outputs) {
  const auto node_offset = gb.AddNode();
  gb.GetNodeBuilder(node_offset)->SetName(name).SetOpDef(op_def);

  // The builder is looked up again for every call, exactly as the node is configured step by step.
  int input_slot = 0;
  for (auto it = inputs.begin(); it != inputs.end(); ++it, ++input_slot) {
    gb.GetNodeBuilder(node_offset)->SetInputName(input_slot, it->c_str());
  }

  int output_slot = 0;
  for (auto it = outputs.begin(); it != outputs.end(); ++it, ++output_slot) {
    gb.GetNodeBuilder(node_offset)->SetOutputName(output_slot, it->c_str());
  }

  return node_offset;
}
}  // namespace
// Stateless fixture that groups all ExecuteGraphBuilder unit tests below.
class ExecuteGraphBuilderUt : public testing::Test {};

// Adding equal text twice must yield the same offset, and the two-argument
// overload must report that the text was already present.
TEST_F(ExecuteGraphBuilderUt, AddStringDuplicated) {
  ExecuteGraphBuilder gb;
  PoolOffset first_x;
  PoolOffset data_text_offset;
  PoolOffset second_x;
  {
    // The source strings only live inside this scope; the offsets are checked after they are gone.
    std::string x_text("x");
    std::string data_text("Data");
    std::string x_text_again("x");

    first_x = gb.AddString(x_text.c_str());
    data_text_offset = gb.AddString(data_text.c_str());
    second_x = gb.AddString(x_text_again.c_str());

    bool already_present = false;
    gb.AddString(data_text.c_str(), already_present);
    EXPECT_TRUE(already_present);
  }

  EXPECT_NE(first_x, kInvalidOffset);
  EXPECT_NE(data_text_offset, kInvalidOffset);
  EXPECT_NE(second_x, kInvalidOffset);

  EXPECT_EQ(first_x, second_x);
}

// Registering two different op types must give a valid offset for each.
TEST_F(ExecuteGraphBuilderUt, AddOpDefOk) {
  ExecuteGraphBuilder gb;

  const auto reshape_offset = gb.AddOpDef("Reshape");
  const auto merge_offset = gb.AddOpDef("Merge");

  EXPECT_NE(reshape_offset, kInvalidOffset);
  EXPECT_NE(merge_offset, kInvalidOffset);
}

// Registering "Reshape" a second time must hand back the offset of the first registration.
TEST_F(ExecuteGraphBuilderUt, AddOpDefDuplicated) {
  ExecuteGraphBuilder gb;

  const auto reshape_first = gb.AddOpDef("Reshape");
  const auto merge_offset = gb.AddOpDef("Merge");
  const auto reshape_again = gb.AddOpDef("Reshape");

  EXPECT_NE(reshape_first, kInvalidOffset);
  EXPECT_NE(merge_offset, kInvalidOffset);
  EXPECT_EQ(reshape_first, reshape_again);
}

/**
 * Builds a graph that contains exactly one node (a MatMul) and verifies the
 * node, its op desc and its op def.
 */
TEST_F(ExecuteGraphBuilderUt, BuildSingleOpGraph) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);
  AddNode(gb, "matmul", reg.matmul, {"x1", "x2"}, {"y"});

  auto graph = gb.Build();
  // Fatal assertions guard every pointer that is dereferenced afterwards, so a
  // failed build is reported as a test failure instead of crashing the binary.
  ASSERT_NE(graph, nullptr);

  EXPECT_EQ(graph->GetNodeCount(), 1);
  EXPECT_EQ(graph->GetEdgeCount(), 0);

  auto node = graph->GetNodeByIndex(0);
  ASSERT_NE(node, nullptr);
  EXPECT_EQ(node->GetInDataNodeByIndex(0), nullptr);
  EXPECT_EQ(node->GetInDataNodeByIndex(1), nullptr);

  auto op_desc = node->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);
  EXPECT_GE(op_desc->GetId(), 0);
  EXPECT_STREQ(op_desc->GetName(), "matmul");
  EXPECT_EQ(op_desc->GetInputDescCount(), 2);
  EXPECT_EQ(op_desc->GetOutputDescCount(), 1);
  EXPECT_STREQ(op_desc->GetInputNameByIndex(0), "x1");
  EXPECT_STREQ(op_desc->GetInputNameByIndex(1), "x2");
  EXPECT_EQ(op_desc->GetInputNameByIndex(2), nullptr);
  EXPECT_STREQ(op_desc->GetOutputNameByIndex(0), "y");
  EXPECT_EQ(op_desc->GetOutputNameByIndex(1), nullptr);

  auto op_def = op_desc->GetOpDef();
  ASSERT_NE(op_def, nullptr);
  EXPECT_STREQ(op_def->type, "MatMul");
  EXPECT_EQ(op_def->infer_shape_func, ops::MatMulInferShape);
  EXPECT_EQ(op_def->tiling_func, TestTilingFunc);

  // The sizes must be right before the elements are indexed.
  ASSERT_EQ(op_def->inputs_def.size(), 3);
  EXPECT_EQ(op_def->inputs_def[0].e_type, kMustIo);
  EXPECT_STREQ(op_def->inputs_def[0].name, "x1");
  EXPECT_EQ(op_def->inputs_def[1].e_type, kMustIo);
  EXPECT_STREQ(op_def->inputs_def[1].name, "x2");
  EXPECT_EQ(op_def->inputs_def[2].e_type, kOptionalIo);
  EXPECT_STREQ(op_def->inputs_def[2].name, "bias");
  ASSERT_EQ(op_def->outputs_def.size(), 1);
  EXPECT_EQ(op_def->outputs_def[0].e_type, kMustIo);
  EXPECT_STREQ(op_def->outputs_def[0].name, "y");
}

/**
 * netoutput
 *     |
 *   conv2d
 *   /    \
 * data  const
 *
 * Builds the graph above and checks node/edge counts, connectivity, the number
 * of tensor descs in the pool, and that each edge's two ends share one tensor desc.
 */
TEST_F(ExecuteGraphBuilderUt, BuildGraphWithEdges) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data_offset = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  auto const_offset = AddNode(gb, "const", reg.const_offset, {}, {"y"});
  auto conv2d_offset = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});
  auto netoutput_offset = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(conv2d_offset)->SetAttr(0, 0).SetAttr(1, 1);

  gb.AddEdge(data_offset, 0, conv2d_offset, 0)
      .AddEdge(const_offset, 0, conv2d_offset, 1)
      .AddEdge(conv2d_offset, 0, netoutput_offset, 0);

  auto graph = gb.Build();
  // Fatal: everything below dereferences the graph.
  ASSERT_NE(graph, nullptr);
  EXPECT_EQ(graph->GetNodeCount(), 4);
  EXPECT_EQ(graph->GetEdgeCount(), 3);
  EXPECT_TRUE(Connected(graph.get(), "data", 0, "conv2d", 0));
  EXPECT_TRUE(Connected(graph.get(), "const", 0, "conv2d", 1));
  EXPECT_TRUE(Connected(graph.get(), "conv2d", 0, "netoutput", 0));

  GraphPoolReader pool_reader(graph.get());
  EXPECT_EQ(pool_reader.GetTensorDescCount(), 5);

  EXPECT_TRUE(SameTensor(graph.get(), "data", 0, "conv2d", 0));
  EXPECT_TRUE(SameTensor(graph.get(), "const", 0, "conv2d", 1));
  EXPECT_TRUE(SameTensor(graph.get(), "conv2d", 0, "netoutput", 0));
}

/**
 * netoutput
 *     |
 *   conv2d
 *   /    \
 * data  const
 *
 * Same graph as above, but conv2d is added before const.
 */
TEST_F(ExecuteGraphBuilderUt, BuildGraphWrongTopoOrder) {
  // Nodes currently have to be added to the ExecuteGraphBuilder in topological
  // order; otherwise building fails.
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  const auto data = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  const auto conv2d = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});
  const auto constant = AddNode(gb, "const", reg.const_offset, {}, {"y"});
  const auto netoutput = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.AddEdge(data, 0, conv2d, 0)
      .AddEdge(constant, 0, conv2d, 1)
      .AddEdge(conv2d, 0, netoutput, 0);

  EXPECT_EQ(gb.Build(), nullptr);
}
/**
 * netoutput
 *     |
 *   conv2d
 *   /    \
 * data  const
 *
 * Checks the op types of the built nodes and the tensor-desc accessors of the
 * conv2d and data op descs (by index and by name, valid and invalid).
 */
TEST_F(ExecuteGraphBuilderUt, BuildCorrectOpDesc) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data_offset = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  auto const_offset = AddNode(gb, "const", reg.const_offset, {}, {"y"});
  auto conv2d_offset = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});
  auto netoutput_offset = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(conv2d_offset)->SetAttr(0, 0).SetAttr(1, 1);

  gb.AddEdge(data_offset, 0, conv2d_offset, 0)
      .AddEdge(const_offset, 0, conv2d_offset, 1)
      .AddEdge(conv2d_offset, 0, netoutput_offset, 0);

  auto graph = gb.Build();
  // Fatal guards: the graph and its four nodes are dereferenced unconditionally below.
  ASSERT_NE(graph, nullptr);
  ASSERT_EQ(graph->GetNodeCount(), 4);

  EXPECT_STREQ(graph->GetNodeByIndex(0)->GetOpDesc()->GetType(), "Data");
  EXPECT_STREQ(graph->GetNodeByIndex(1)->GetOpDesc()->GetType(), "Const");
  EXPECT_STREQ(graph->GetNodeByIndex(2)->GetOpDesc()->GetType(), "Conv2D");
  EXPECT_STREQ(graph->GetNodeByIndex(3)->GetOpDesc()->GetType(), "NetOutput");

  // conv2d: two inputs ("x", "filter") and one output ("y").
  auto op_desc = graph->GetNodeByIndex(2)->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);
  EXPECT_NE(op_desc->MutableInputDesc(0), nullptr);
  EXPECT_NE(op_desc->MutableInputDesc(1), nullptr);
  EXPECT_EQ(op_desc->MutableInputDesc(2), nullptr);
  EXPECT_EQ(op_desc->MutableInputDesc(-1), nullptr);
  EXPECT_NE(op_desc->MutableOutputDesc(0), nullptr);
  EXPECT_EQ(op_desc->MutableOutputDesc(1), nullptr);
  EXPECT_EQ(op_desc->MutableOutputDesc(-1), nullptr);

  EXPECT_NE(op_desc->MutableInputDesc("x"), nullptr);
  EXPECT_NE(op_desc->MutableInputDesc("filter"), nullptr);
  EXPECT_EQ(op_desc->MutableInputDesc("something"), nullptr);
  EXPECT_NE(op_desc->MutableOutputDesc("y"), nullptr);
  EXPECT_EQ(op_desc->MutableOutputDesc("hello"), nullptr);

  // data: one input ("x") and one output ("y").
  op_desc = graph->GetNodeByIndex(0)->GetOpDesc();
  ASSERT_NE(op_desc, nullptr);
  EXPECT_TRUE(op_desc->IsTensorDescExists(op_desc->GetInputDesc(0)));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetInputDesc(1)));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetInputDesc(-1)));

  EXPECT_TRUE(op_desc->IsTensorDescExists(op_desc->GetOutputDesc(0)));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetOutputDesc(1)));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetOutputDesc(-1)));

  EXPECT_TRUE(op_desc->IsTensorDescExists(op_desc->GetInputDesc("x")));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetInputDesc("y")));

  EXPECT_TRUE(op_desc->IsTensorDescExists(op_desc->GetOutputDesc("y")));
  EXPECT_FALSE(op_desc->IsTensorDescExists(op_desc->GetOutputDesc("x")));
}
/**
 *    netoutput
 *        |
 *      matmul
 *      /   \
 * trans0    trans1
 *   |         |
 * data0      data1
 *
 * Builds the graph above and checks counts, connectivity and node order.
 */
TEST_F(ExecuteGraphBuilderUt, BuildMatMulGraphOk) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data0 = AddNode(gb, "data0", reg.data, {"x"}, {"y"});
  auto data1 = AddNode(gb, "data1", reg.data, {"x"}, {"y"});
  auto trans0 = AddNode(gb, "trans0", reg.transdata, {"src"}, {"dst"});
  auto trans1 = AddNode(gb, "trans1", reg.transdata, {"src"}, {"dst"});
  auto matmul = AddNode(gb, "matmul", reg.matmul, {"x1", "x2"}, {"y"});
  auto netoutput = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(trans0)->SetAttr(0, std::string("NCHW")).SetAttr(1, std::string("NC1HWC0"));
  gb.GetNodeBuilder(trans1)->SetAttr(0, std::string("NCHW")).SetAttr(1, std::string("FRACTAL_Z"));

  gb.AddEdge(data0, 0, trans0, 0)
      .AddEdge(data1, 0, trans1, 0)
      .AddEdge(trans0, 0, matmul, 0)
      .AddEdge(trans1, 0, matmul, 1)
      .AddEdge(matmul, 0, netoutput, 0);

  auto graph = gb.Build();
  // Fatal: everything below dereferences the graph.
  ASSERT_NE(graph, nullptr);
  EXPECT_EQ(graph->GetNodeCount(), 6);
  EXPECT_EQ(graph->GetEdgeCount(), 5);
  EXPECT_TRUE(Connected(graph.get(), "data0", 0, "trans0", 0));
  EXPECT_TRUE(Connected(graph.get(), "data1", 0, "trans1", 0));
  EXPECT_TRUE(Connected(graph.get(), "trans0", 0, "matmul", 0));
  EXPECT_TRUE(Connected(graph.get(), "trans1", 0, "matmul", 1));
  EXPECT_TRUE(Connected(graph.get(), "matmul", 0, "netoutput", 0));
  EXPECT_EQ(GetNodeNamesByTopo(graph.get()),
            std::vector<std::string>({"data0", "data1", "trans0", "trans1", "matmul", "netoutput"}));
}

// A null string is rejected with kInvalidOffset and the "exists" flag keeps
// whatever value the caller passed in.
TEST_F(ExecuteGraphBuilderUt, AddNullString) {
  ExecuteGraphBuilder gb;
  EXPECT_EQ(gb.AddString(nullptr), kInvalidOffset);

  bool flag = false;
  EXPECT_EQ(gb.AddString(nullptr, flag), kInvalidOffset);
  EXPECT_FALSE(flag);

  flag = true;
  EXPECT_EQ(gb.AddString(nullptr, flag), kInvalidOffset);
  EXPECT_TRUE(flag);
}

// The two-argument AddString reports through `exists` whether the text was
// already pooled; re-adding a text returns the offset of its first insertion.
TEST_F(ExecuteGraphBuilderUt, AddExistsString) {
  ExecuteGraphBuilder gb;
  // Initialized explicitly: AddString is not guaranteed to write the flag on
  // every path, and reading an indeterminate bool would be undefined behavior.
  bool exists = false;
  auto offset_a = gb.AddString("a", exists);
  EXPECT_FALSE(exists);
  EXPECT_NE(offset_a, kInvalidOffset);
  EXPECT_EQ(offset_a, gb.AddString("a", exists));
  EXPECT_TRUE(exists);

  auto offset_b = gb.AddString("bbb");
  EXPECT_NE(offset_b, kInvalidOffset);
  EXPECT_EQ(offset_b, gb.AddString("bbb"));

  EXPECT_NE(offset_b, offset_a);
}

// A null op type is rejected with kInvalidOffset, also on repeated attempts.
TEST_F(ExecuteGraphBuilderUt, AddNullOpDef) {
  ExecuteGraphBuilder gb;
  constexpr int kAttempts = 3;
  for (int attempt = 0; attempt < kAttempts; ++attempt) {
    EXPECT_EQ(gb.AddOpDef(nullptr), kInvalidOffset);
  }
}

// Two distinct op types get distinct valid offsets; re-adding either type
// returns the offset it was first given.
TEST_F(ExecuteGraphBuilderUt, AddDuplicatedOpDef) {
  ExecuteGraphBuilder gb;
  const auto first_aaa = gb.AddOpDef("aaa");
  const auto first_bbb = gb.AddOpDef("bbb");

  EXPECT_NE(first_aaa, kInvalidOffset);
  EXPECT_NE(first_bbb, kInvalidOffset);
  EXPECT_NE(first_aaa, first_bbb);

  EXPECT_EQ(first_aaa, gb.AddOpDef("aaa"));
  EXPECT_EQ(first_bbb, gb.AddOpDef("bbb"));
}

// An op-def builder can be fetched for every offset returned by AddOpDef.
TEST_F(ExecuteGraphBuilderUt, GetOpDefOk) {
  ExecuteGraphBuilder gb;
  const auto aaa = gb.AddOpDef("aaa");
  const auto bbb = gb.AddOpDef("bbb");

  EXPECT_NE(gb.GetOpDefBuilder(aaa), nullptr);
  EXPECT_NE(gb.GetOpDefBuilder(bbb), nullptr);
}

// Looking up an op-def builder with an offset that was never handed out, or
// with kInvalidOffset, yields nullptr.
TEST_F(ExecuteGraphBuilderUt, GetOpDefNotExists) {
  ExecuteGraphBuilder gb;
  // Populate the builder so the lookups run against a non-empty pool; the
  // returned offsets themselves are not needed here.
  gb.AddOpDef("aaa");
  gb.AddOpDef("bbb");

  EXPECT_EQ(gb.GetOpDefBuilder(10241), nullptr);
  EXPECT_EQ(gb.GetOpDefBuilder(kInvalidOffset), nullptr);
}

// Offsets returned by AddNode are valid and resolve to a node builder, while
// unknown offsets resolve to nullptr.
TEST_F(ExecuteGraphBuilderUt, GetNodeBuilderOk) {
  ExecuteGraphBuilder gb;
  const auto first_node = gb.AddNode();
  const auto second_node = gb.AddNode();

  EXPECT_NE(first_node, kInvalidOffset);
  EXPECT_NE(second_node, kInvalidOffset);

  EXPECT_NE(gb.GetNodeBuilder(first_node), nullptr);
  EXPECT_NE(gb.GetNodeBuilder(second_node), nullptr);

  EXPECT_EQ(gb.GetNodeBuilder(1024), nullptr);
  EXPECT_EQ(gb.GetNodeBuilder(kInvalidOffset), nullptr);
}

// With two nodes present, an offset that was never handed out and
// kInvalidOffset both resolve to nullptr.
TEST_F(ExecuteGraphBuilderUt, GetNodeBuilderNotExists) {
  ExecuteGraphBuilder gb;
  constexpr int kNodeCount = 2;
  for (int added = 0; added < kNodeCount; ++added) {
    gb.AddNode();
  }

  EXPECT_EQ(gb.GetNodeBuilder(1024), nullptr);
  EXPECT_EQ(gb.GetNodeBuilder(kInvalidOffset), nullptr);
}

// An edge whose source index is 0 but whose destination index is -1 must make Build() fail.
TEST_F(ExecuteGraphBuilderUt, AddEdgeDataAndCtrl) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  const auto data = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  const auto conv2d = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});

  gb.AddEdge(data, 0, conv2d, -1);

  EXPECT_EQ(gb.Build(), nullptr);
}

// An edge whose source index is -1 but whose destination index is 0 must make Build() fail.
TEST_F(ExecuteGraphBuilderUt, AddEdgeCtrlAndData) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  const auto data = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  const auto conv2d = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});

  gb.AddEdge(data, -1, conv2d, 0);

  EXPECT_EQ(gb.Build(), nullptr);
}

// An edge whose source node offset is kInvalidOffset must make Build() fail.
TEST_F(ExecuteGraphBuilderUt, AddEdgeInvalidOffset) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  // The data node is still added so the graph matches the sibling edge tests;
  // its offset is deliberately not used as the edge source.
  AddNode(gb, "data", reg.data, {"x"}, {"y"});
  auto conv2d_offset = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});

  gb.AddEdge(kInvalidOffset, 0, conv2d_offset, 0);

  auto graph = gb.Build();
  EXPECT_EQ(graph, nullptr);
}

// The data node is added with a single output, so an edge from its output index 1 must make Build() fail.
TEST_F(ExecuteGraphBuilderUt, AddEdgeNotExistsSrcIndex) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  const auto data = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  const auto conv2d = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});

  gb.AddEdge(data, 1, conv2d, 0);

  EXPECT_EQ(gb.Build(), nullptr);
}

// The conv2d node is added with two inputs, so an edge into its input index 2 must make Build() fail.
TEST_F(ExecuteGraphBuilderUt, AddEdgeNotExistsDstIndex) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  const auto data = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  const auto conv2d = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});

  gb.AddEdge(data, 0, conv2d, 2);

  EXPECT_EQ(gb.Build(), nullptr);
}

// Control edges added through AddControlEdge, or through AddEdge with index -1
// on both ends, are expected to show up in the built graph as -1/-1 edges next
// to the regular data edge.
TEST_F(ExecuteGraphBuilderUt, AddCtrlEdgeOk) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data_offset = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  auto const_offset = AddNode(gb, "const", reg.const_offset, {}, {"y"});
  auto conv2d_offset = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});
  auto netoutput_offset = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(conv2d_offset)->SetAttr(0, 0).SetAttr(1, 1);

  gb.AddControlEdge(data_offset, const_offset)
      .AddControlEdge(data_offset, conv2d_offset)
      .AddEdge(const_offset, -1, netoutput_offset, -1)
      .AddEdge(conv2d_offset, 0, netoutput_offset, 0);

  auto graph = gb.Build();
  // Fatal: everything below dereferences the graph.
  ASSERT_NE(graph, nullptr);
  EXPECT_EQ(graph->GetEdgeCount(), 4);
  EXPECT_TRUE(Connected(graph.get(), "data", -1, "const", -1));
  EXPECT_TRUE(Connected(graph.get(), "data", -1, "conv2d", -1));
  EXPECT_TRUE(Connected(graph.get(), "const", -1, "netoutput", -1));
  EXPECT_TRUE(Connected(graph.get(), "conv2d", 0, "netoutput", 0));
}

/**
 *   netoutput
 *     | |
 *    split
 *   /    \
 * data0   const0
 *
 * Two parallel data edges between the same pair of nodes (split -> netoutput)
 * must both be kept.
 */
TEST_F(ExecuteGraphBuilderUt, MultipleEdgesTwoNodes) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data_offset = AddNode(gb, "data0", reg.data, {"x"}, {"y"});
  auto const_offset = AddNode(gb, "const0", reg.const_offset, {}, {"y"});
  auto split_offset = AddNode(gb, "split", reg.split, {"split_dim", "x"}, {"y1", "y2"});
  auto netoutput_offset = AddNode(gb, "netoutput", reg.netoutput, {"x1", "x2"}, {"y1", "y2"});

  gb.GetNodeBuilder(split_offset)->SetAttr(0, 0);

  gb.AddEdge(data_offset, 0, split_offset, 0)
      .AddEdge(const_offset, 0, split_offset, 1)
      .AddEdge(split_offset, 0, netoutput_offset, 0)
      .AddEdge(split_offset, 1, netoutput_offset, 1);

  auto graph = gb.Build();
  // Fatal: everything below dereferences the graph.
  ASSERT_NE(graph, nullptr);

  EXPECT_EQ(graph->GetEdgeCount(), 4);
  EXPECT_TRUE(Connected(graph.get(), "data0", 0, "split", 0));
  EXPECT_TRUE(Connected(graph.get(), "const0", 0, "split", 1));
  EXPECT_TRUE(Connected(graph.get(), "split", 1, "netoutput", 1));
  EXPECT_TRUE(Connected(graph.get(), "split", 0, "netoutput", 0));
}

// Adding the same control edge and the same data edge twice is accepted, and
// the built graph is expected to contain each of them only once.
TEST_F(ExecuteGraphBuilderUt, AddDuplicatedEdgesOk) {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data_offset = AddNode(gb, "data", reg.data, {"x"}, {"y"});
  auto const_offset = AddNode(gb, "const", reg.const_offset, {}, {"y"});
  auto conv2d_offset = AddNode(gb, "conv2d", reg.conv2d, {"x", "filter"}, {"y"});
  auto netoutput_offset = AddNode(gb, "netoutput", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(conv2d_offset)->SetAttr(0, 0).SetAttr(1, 1);

  // AddEdge(..., -1, ..., -1) and AddControlEdge describe the same data -> conv2d edge.
  gb.AddEdge(data_offset, -1, conv2d_offset, -1)
      .AddEdge(conv2d_offset, 0, netoutput_offset, 0)
      .AddControlEdge(data_offset, conv2d_offset)
      .AddEdge(conv2d_offset, 0, netoutput_offset, 0);

  auto graph = gb.Build();
  // Fatal: everything below dereferences the graph.
  ASSERT_NE(graph, nullptr);
  EXPECT_EQ(graph->GetEdgeCount(), 2);
  EXPECT_TRUE(Connected(graph.get(), "data", -1, "conv2d", -1));
  EXPECT_TRUE(Connected(graph.get(), "conv2d", 0, "netoutput", 0));
}
}  // namespace ge